{
 "cells": [
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "- [GaussianMixture 高斯混合模型概率分布](http://scikit-learn.org/stable/modules/generated/sklearn.mixture.GaussianMixture.html#sklearn.mixture.GaussianMixture)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 1,
   "metadata": {
    "collapsed": true
   },
   "outputs": [],
   "source": [
    "import numpy as np\n",
    "import itertools\n",
    "from scipy import linalg\n",
    "import matplotlib as mpl\n",
    "import matplotlib.colors\n",
    "import matplotlib.pyplot as plt\n",
    "\n",
    "from sklearn.mixture import GaussianMixture"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 2,
   "metadata": {
    "collapsed": true
   },
   "outputs": [],
   "source": [
    "# Fix rendering of Chinese text in figures: select the SimHei font for sans-serif text\n",
    "# and turn off the axes.unicode_minus setting.\n",
    "mpl.rcParams.update({\n",
    "    'font.sans-serif': [u'SimHei'],\n",
    "    'axes.unicode_minus': False,\n",
    "})"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 3,
   "metadata": {
    "collapsed": true
   },
   "outputs": [],
   "source": [
    "# Choose how matplotlib figures are shown in Jupyter: the default `inline` embeds them in\n",
    "# the notebook, while setting `tk` shows them outside the notebook (not embedded).\n",
    "%matplotlib tk"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 4,
   "metadata": {
    "collapsed": true
   },
   "outputs": [],
   "source": [
    "def expand(a, b, rate=0.05):\n",
    "    d = (b - a) * rate\n",
    "    return a-d, b+d"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 5,
   "metadata": {
    "collapsed": true
   },
   "outputs": [],
   "source": [
    "## 样本数据产生\n",
    "n_samples = 500\n",
    "\n",
    "np.random.seed(28)\n",
    "C = np.array([[0., -0.1], [1.7, .4]])\n",
    "X = np.r_[np.dot(np.random.randn(n_samples, 2), C), .7 * np.random.randn(n_samples, 2) + np.array([-6, 3])]"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 7,
   "metadata": {
    "collapsed": false
   },
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "均值:\n",
      " [[-5.96170456  3.0078179 ]\n",
      " [-0.03908874 -0.00864228]]\n",
      "方差:\n",
      " [[[ 0.52760145  0.02227373]\n",
      "  [ 0.02227373  0.49288209]]\n",
      "\n",
      " [[ 2.59795894  0.61525644]\n",
      "  [ 0.61525644  0.156022  ]]]\n"
     ]
    }
   ],
   "source": [
    "## Compare the effect of different parameters\n",
    "# Fit a GaussianMixture for every covariance type and component count, keeping the\n",
    "# fit with the lowest BIC. `bics` is filled covariance-type-major (all component\n",
    "# counts for 'spherical' first, then 'tied', ...); the plotting cell slices it in\n",
    "# that order.\n",
    "# np.inf is used instead of np.infty: that alias was removed in NumPy 2.0.\n",
    "lowest_bic = np.inf\n",
    "bics = []\n",
    "n_components_range = range(1, 7)\n",
    "cv_types = ['spherical', 'tied', 'diag', 'full']\n",
    "\n",
    "for cv_type in cv_types:\n",
    "    for n_components in n_components_range:\n",
    "        gmm = GaussianMixture(n_components=n_components, covariance_type=cv_type)\n",
    "        gmm.fit(X)\n",
    "        # Original author's note: the clustering result is very good, but the parameter\n",
    "        # (i.e. the variance) is very large, which is unreasonable.\n",
    "        # Original author's question: is a smaller BIC better? (This loop treats it so.)\n",
    "        bics.append(gmm.bic(X))\n",
    "        if bics[-1] < lowest_bic:\n",
    "            lowest_bic = bics[-1]\n",
    "            best_gmm = gmm\n",
    "\n",
    "# Take the best model found above and use it to label every sample\n",
    "clf = best_gmm\n",
    "Y_ = clf.predict(X)\n",
    "\n",
    "print (\"均值:\\n\", clf.means_)\n",
    "print (\"方差:\\n\", clf.covariances_)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 8,
   "metadata": {
    "collapsed": false
   },
   "outputs": [],
   "source": [
    "# Plotting\n",
    "bics = np.array(bics)\n",
    "color_iter = itertools.cycle(['navy', 'turquoise', 'cornflowerblue', 'darkorange'])\n",
    "\n",
    "## Top panel: compare the BIC of every (covariance type, component count) fit\n",
    "spl = plt.subplot(2, 1, 1)\n",
    "bars = []\n",
    "for i, (cv_type, color) in enumerate(zip(cv_types, color_iter)):\n",
    "    # Shift each covariance type's bars by .2 so the groups sit side by side.\n",
    "    xpos = np.array(n_components_range) + .2 * (i - 2)\n",
    "    bars.append(plt.bar(xpos, bics[i * len(n_components_range):(i + 1) * len(n_components_range)], width=.2, color=color))\n",
    "\n",
    "plt.xticks(n_components_range)\n",
    "plt.ylim([bics.min() * 1.01 - .01 * bics.max(), bics.max()])\n",
    "plt.title(u'不同模型参数下BIC的值')\n",
    "# Mark the lowest-BIC bar with a '*': x position is derived from the argmin's\n",
    "# component count (index modulo) and covariance type (index floor-divided).\n",
    "xpos = np.mod(bics.argmin(), len(n_components_range)) + .65 + .2 * np.floor(bics.argmin() / len(n_components_range))\n",
    "plt.text(xpos, bics.min() * 0.97 + .03 * bics.max(), '*', fontsize=14)\n",
    "spl.set_xlabel(u'类别数量')\n",
    "spl.legend([b[0] for b in bars], cv_types)\n",
    "\n",
    "# Bottom panel: clustering result of the best model\n",
    "# (original author's note: the optimal number of classes can be seen to be 2)\n",
    "splot = plt.subplot(2, 1, 2)\n",
    "cm_light = mpl.colors.ListedColormap(['#FFA0A0', '#A0FFA0'])\n",
    "cm_dark = mpl.colors.ListedColormap(['r', 'g'])\n",
    "\n",
    "# Data bounding box, padded by expand() so points do not sit on the frame.\n",
    "x1_min, x1_max = X[:, 0].min(), X[:, 0].max()\n",
    "x2_min, x2_max = X[:, 1].min(), X[:, 1].max()\n",
    "x1_min, x1_max = expand(x1_min, x1_max)\n",
    "x2_min, x2_max = expand(x2_min, x2_max)\n",
    "# 500 x 500 grid over the padded box; predict a component for every grid node.\n",
    "x1, x2 = np.mgrid[x1_min:x1_max:500j, x2_min:x2_max:500j]\n",
    "grid_test = np.stack((x1.flat, x2.flat), axis=1)\n",
    "grid_hat = clf.predict(grid_test)\n",
    "grid_hat = grid_hat.reshape(x1.shape)\n",
    "\n",
    "# For a two-component model, relabel so that component 0 is the one with the smaller\n",
    "# first-coordinate mean. The same relabelling is applied to the sample labels, so the\n",
    "# scatter colours always agree with the background regions. Y_ itself is not modified,\n",
    "# which keeps this cell safe to re-run. Models with a different number of components\n",
    "# are plotted with their labels unchanged.\n",
    "point_labels = Y_\n",
    "if len(clf.means_) == 2 and clf.means_[0][0] > clf.means_[1][0]:\n",
    "    grid_hat = 1 - grid_hat\n",
    "    point_labels = 1 - Y_\n",
    "\n",
    "plt.pcolormesh(x1, x2, grid_hat, cmap=cm_light)\n",
    "plt.scatter(X[:, 0], X[:, 1], s=30, c=point_labels, marker='o', cmap=cm_dark, edgecolors='k')\n",
    "plt.xlim((x1_min, x1_max))\n",
    "plt.ylim((x2_min, x2_max))\n",
    "plt.title(u'最优参数下GMM算法效果')\n",
    "plt.subplots_adjust(hspace=.35, bottom=.02)\n",
    "plt.grid()\n",
    "\n",
    "plt.show()"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "collapsed": true
   },
   "outputs": [],
   "source": []
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "Python [default]",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.5.2"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 2
}
